import json

from fastapi import WebSocket, WebSocketDisconnect
from typing import Union
import asyncio
import aioredis
from app.conf import settings

# Connection parameters for the Redis server used for pub/sub messaging,
# read from the application settings (database index fixed at 0).
redis_conf = dict(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    password=settings.REDIS_PASSWORD,
    db=0,
)


# Shared asyncio Redis client built from ``redis_conf``. Replies are decoded
# to ``str`` using UTF-8 (decode_responses=True).
redis_client = aioredis.from_url(
    f"redis://{redis_conf['host']}",
    db=redis_conf['db'],
    password=redis_conf['password'],
    port=redis_conf['port'],
    encoding="utf-8",
    decode_responses=True,
)


async def publish_message(topic: str, message: Union[str, dict], message_type: str = "default"):
    """Publish *message* on the Redis pub/sub channel *topic*.

    The payload is the JSON object ``{"message": ..., "message_type": ...}``,
    which is the envelope ``reader`` decodes on the subscribing side.

    Args:
        topic: Redis channel name (``register_pubsub`` subscribes each client
            to a channel named after its client id).
        message: Body to deliver; must be JSON-serializable.
        message_type: Label forwarded to the client as the ``type`` field.
    """
    send_msg = json.dumps({
        "message": message,
        "message_type": message_type,
    })
    # Reuse the module-level client and its connection pool. Building a new
    # client here on every call created a fresh pool each time and never
    # closed it.
    await redis_client.publish(topic, send_msg)


class ConnectionManager:
    """Registry of accepted WebSocket connections, keyed by client id.

    Each client id maps to at most one WebSocket; a later ``connect`` with the
    same id overwrites the stored socket.
    """

    def __init__(self):
        # client_id -> WebSocket object of every currently connected client
        self.websocket_connections = {}

    async def connect(self, websocket: WebSocket, client_id):
        """Accept *websocket*, register it under *client_id* and serve it.

        Starts a background task running ``register_pubsub(client_id)`` so
        messages published on the client's Redis channel are relayed to it.
        Then every text frame received from the client is passed to
        :meth:`handle_websocket_message`, which sends it back to the socket
        registered for *client_id*.

        Returns normally when the client disconnects; any other exception
        propagates. In every case the relay task is cancelled, and the
        registration is removed if it still refers to this socket.
        """
        await websocket.accept()
        self.websocket_connections[client_id] = websocket

        # asyncio.create_task is the supported way to spawn a task from a
        # running coroutine. The local reference keeps the task alive while
        # this connection is served; the event loop keeps only weak references
        # to tasks.
        pubsub_task = asyncio.create_task(register_pubsub(client_id))

        try:
            while True:
                message = await websocket.receive_text()
                await self.handle_websocket_message(message, client_id)
        except WebSocketDisconnect:
            # Normal client-initiated close.
            pass
        finally:
            # Stop relaying this client's Redis channel. Otherwise the
            # subscription would outlive the socket indefinitely.
            pubsub_task.cancel()
            # close()/disconnect() may already have removed the entry, or a
            # newer socket may have replaced it under the same client_id.
            # Only drop the entry if it is still this socket.
            if self.websocket_connections.get(client_id) is websocket:
                del self.websocket_connections[client_id]

    async def handle_websocket_message(self, message: Union[dict, str], client_id: str, message_type: str = "default"):
        """Send *message* as JSON to the socket registered for *client_id*.

        The frame is ``{"type": message_type, "sender": client_id,
        "msg": message}``. Nothing is sent if *client_id* has no registered
        socket (i.e. the client is offline).
        """
        recipient_conn = self.websocket_connections.get(client_id)

        if recipient_conn:
            # Client is online.
            await recipient_conn.send_json({"type": message_type,
                                            "sender": client_id,
                                            "msg": message,
                                            })

    async def broadcast(self, message: Union[dict, str]):
        """Send *message* to every registered connection.

        Strings go out as text frames; anything else is serialized with
        ``send_json``.
        """
        # Iterate the sockets, not the client-id keys, and snapshot them: the
        # dict can change while we are awaiting a send.
        for connection in list(self.websocket_connections.values()):
            if isinstance(message, str):
                await connection.send_text(message)
            else:
                await connection.send_json(message)

    async def close(self, websocket: WebSocket, client_id):
        """Close *websocket* and drop the registration for *client_id*."""
        await websocket.close()
        # connect() may already have removed the entry while we awaited the
        # close above, so tolerate a missing key.
        self.websocket_connections.pop(client_id, None)

    async def disconnect(self, user_id):
        """Close and unregister the socket of *user_id*.

        Raises:
            KeyError: if *user_id* has no registered socket.
        """
        # Unregister before awaiting the close so that connect()'s cleanup,
        # which can run during that await, does not race with us.
        websocket: WebSocket = self.websocket_connections.pop(user_id)
        await websocket.close()


# Module-level connection registry, shared by the WebSocket handlers and the
# Redis pub/sub ``reader`` that delivers published messages.
websocket_manager: ConnectionManager = ConnectionManager()


async def reader(channel, client_id):
    """Consume messages from a subscribed pub/sub *channel* and relay them.

    Each published payload is expected to be the JSON envelope written by
    ``publish_message`` (``{"message": ..., "message_type": ...}``). Its
    contents are forwarded to *client_id*'s WebSocket through
    ``websocket_manager``.

    Runs until ``channel.listen()`` stops yielding.
    """
    async for msg in channel.listen():
        msg_data = msg.get("data")
        # Only non-empty string payloads are relayed. Non-string data, such as
        # the integer count in subscribe confirmations, is skipped.
        if not msg_data or not isinstance(msg_data, str):
            continue
        try:
            msg_data_dict = json.loads(msg_data)
        except json.JSONDecodeError:
            # A malformed payload must not kill the whole subscription task.
            continue
        if not isinstance(msg_data_dict, dict):
            # Valid JSON but not the expected envelope object; skip it.
            continue
        message = msg_data_dict.get('message', {})
        message_type = msg_data_dict.get('message_type', 'default')
        await websocket_manager.handle_websocket_message(message, client_id, message_type)


async def register_pubsub(client_id):
    """Subscribe to the Redis channel named *client_id* and relay its messages.

    Uses the shared ``redis_client``. Messages are handed to ``reader`` until
    it returns; only then is the channel explicitly unsubscribed. Cleanup of
    the pub/sub object is left to its async context manager on exit.
    """
    async with redis_client.pubsub() as subscription:
        await subscription.subscribe(client_id)
        await reader(subscription, client_id)
        await subscription.unsubscribe(client_id)


